{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# 实现逻辑回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import datasets\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = datasets.load_iris()\n",
    "X = iris.data\n",
    "y = iris.target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((100, 2), (100,))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 我们知道逻辑回归是解决2分类问题的，但鸢尾花数据集有3个分类，所以我们需要只取其中2个分类\n",
    "# 另外为了可视化，我们只取其中2个特征而不是全部特征\n",
    "X = X[y<2,:2]\n",
    "y = y[y<2]\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD8CAYAAACMwORRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAFzRJREFUeJzt3X2MXFd5x/Hf45kUMG+RyAqi+GUrgagAhRCvQigIhdhUIVjmD6iaaikNauXiDSW0VLw0UqpaQqhCokDBRiujKqndEhqgDSilDYEW+gep1iEJBNMqUDuJS5uNKUlTt6lsP/3j3sW7s7Mz98zMmTnnzPcjXe3MnZO7z7n3+sndc597xtxdAICybJp0AACA0SO5A0CBSO4AUCCSOwAUiOQOAAUiuQNAgUjuAFAgkjsAFIjkDgAFajdtaGYtSUuSTrr77o7Prpf0UUkn61WfcvdDvbZ30UUX+ezsbFCwADDtjh49+ri7z/Rr1zi5S7pR0jFJz9vg89vc/d1NNzY7O6ulpaWAXw8AMLMTTdo1GpYxsy2S3iyp59U4ACANTcfcPy7p/ZLO9WjzVjN7wMxuN7Ot3RqY2V4zWzKzpeXl5dBYAQAN9U3uZrZb0mPufrRHsy9LmnX3SyXdJemWbo3cfdHd59x9bmam75ARAGBATa7cXytpj5kdl/Q5SVeb2eHVDdz9lLs/Xb89JGnHSKMEAATpm9zd/UPuvsXdZyVdJ+nr7v721W3M7OJVb/eouvEKAJiQkGqZNcxsv6Qld79D0nvMbI+kM5J+Iun60YQHABhE0ENM7v73KzXu7n5zndhXru5f7u6vdPc3uPsPYgQLTMSRI9LsrLRpU/XzyJFJRwT0NfCVOzAVjhyR9u6VTp+u3p84Ub2XpPn5ycUF9MH0A0AvN910PrGvOH26Wg8kjOQO9PLww2HrgUSQ3IFetm0LWw8kguQO9PLhD0ubN69dt3lztR5IGMkd6GV+XlpclLZvl8yqn4uL3ExF8qiWAfqZnyeZIztcuQNAgUjuAFAgkjsAFIjkDgAFIrkDQIFI7gBQIJI7ABSI5A4ABSK5A0CBSO4oB1+qAfwM0w+gDHypBrAGV+4oA1+qAaxBckcZ+FINYA2SO8rAl2oAa5DcUQa+VANYg+SOMvClGsAaVMugHHypBvAzXLljeNSXA8nhyh3Dob4cSBJX7hgO9eVAkkjuGA715UCSSO4YDvXlQJJI7hgO9eVAkkjuGA715UCSGlfLmFlL0pKkk+6+u+OzZ0i6VdIOSack/Yq7Hx9hnEgZ9eVAckKu3G+UdGyDz35D0n+6+4sl/bGkPxo2MCBL1PwjEY2Su5ltkfRmSYc2aPIWSbfUr2+XtNPMbPjwgIys1PyfOCG5n6/5J8FjAppeuX9c0vslndvg80skPSJJ7n5G0hOSXjB0dEBOqPlHQvomdzPbLekxdz867C8zs71mtmRmS8vLy8NuDkgLNf9ISJMr99dK2mNmxyV9TtLVZna4o81JSVslyczakp6v6sbqGu6+6O5z7j43MzMzVOBAcqj5R0L6Jnd3/5C7b3H3WUnXSfq6u7+9o9kdkn69fv22uo2PNFIgddT8IyED17mb2X4z21O//aykF5jZQ5J+V9IHRxEckBVq/pEQm9QF9tzcnC8tLU3kdwNArszsqLvP9WvHE6pI18KC1G5XV8HtdvUeQCPM5440LSxIBw+ef3/27Pn3Bw5MJiYgI1y5I02Li2HrAaxBckeazp4NWw9gDZI70tRqha0HsAbJHWla+R7WpusBrMENVaRp5abp4mI1FNNqVYmdm6lAIyR3pOvAAZI5MCCGZdDdrl1VffnKsmvXpCOaHOZoR4ZI7lhv1y7p7rvXrrv77ulM8MzRjkwx/QDW6/U9K9M2H9zsbJXQO23fLh0/Pu5oAKYfAEaCOdqRKZI70AtztCNTJHest3Nn2PqSMUc7MkVyx3pf+9r6
RL5zZ7V+2jBHOzLFDVUAyAg3VDGcWLXdIdulvhwYGE+oYr2V2u7Tp6v3K7Xd0nDDESHbjRUDMCUYlsF6sWq7Q7ZLfTnQFcMyGFys2u6Q7VJfDgyF5I71YtV2h2yX+nJgKCR3rBertjtku9SXA0MhuWO9WLXdIdulvhwYCjdUASAj3FCNLcca7BxjBjAQ6twHkWMNdo4xAxgYwzKDyLEGO8eYAazDsExMOdZg5xgzgIGR3AeRYw12jjEDGBjJfRA51mDnGDOAgZHcB5FjDXaOMQMYWN8bqmb2TEnflPQMVdU1t7v7H3S0uV7SRyWdrFd9yt0P9dpu1jdUAWBCRnlD9WlJV7v7KyVdJukaM7uyS7vb3P2yeumZ2DEhCwtSu11dubfb1ftRtE2lfj6VOIAE9K1z9+rS/qn67QX1Mpn6SQxuYUE6ePD8+7Nnz78/cGDwtqnUz6cSB5CIRnXuZtaSdFTSiyV92t0/0PH59ZI+ImlZ0r9I+h13f6TXNhmWGbN2u0rSnVot6cyZwdumUj+fShxAZCOtc3f3s+5+maQtkq4ws1d0NPmypFl3v1TSXZJu2SCovWa2ZGZLy8vLTX41RqVbst5ofUjbVOrnU4kDSERQtYy7/1TSNyRd07H+lLs/Xb89JGnHBv/9orvPufvczMzMIPFiUK1W8/UhbVOpn08lDiARfZO7mc2Y2YX162dJeqOkH3S0uXjV2z2Sjo0ySIzAyvhzk/UhbVOpn08lDiAV7t5zkXSppO9IekDS9yTdXK/fL2lP/fojkh6UdL+qK/tf6LfdHTt2OMZs3z73Vstdqn7u2zeatocPu2/f7m5W/Tx8eNSRN5NKHEBEkpa8T351dyYOA4CcMHFYbLFqqkPqy2NuO6R/Oe6LzFDCj2BNLu9jLFkPyxw+7L55czVksbJs3jz8MMC+fWu3ubL0GhKJse2Q/uW4LzITaxcjT2JYJqJYNdUh9eUxtx3Svxz3RWYo4cdqTYdlSO6D2LSpuoDqZCadOzf4ds02/mzY4xSy7ZD+5bgvMhNrFyNPjLnHFKumOqS+POa2Q/qX477IDCX8GATJfRCxaqpD6stjbjukfznui8xQwo+BNBmYj7FkfUPVPV5NdUh9ecxth/Qvx32RGUr4sULcUAWA8jDmjvVSqF1H1jgt8tF3PncUImS+c+ZGRxecFnlhWGZapFC7jqxxWqSBYRmsFTLfOXOjowtOi7yQ3KdFCrXryBqnRV5I7tMihdp1ZI3TIi8k92kxPy8tLlYDpGbVz8XF7nfCQtpianBa5IUbqgCQEW6orohVmBuy3VTmJadIOSmlH47S+xdiIvuiyWOsMZaxTD8QayLskO2mMi85k4InpfTDUXr/Qox6X4jpBxSvMDdku6nMS06RclJKPxyl9y/EqPcF87lL8SbCDtluKvOSMyl4Uko/HKX3L8So9wVj7lK8wtyQ7aYyLzlFykkp/XCU3r8Qk9oXZSf3WIW5IdtNZV5yipSTUvrhKL1/ISa2L5oMzMdYxjafe6yJsEO2m8q85EwKnpTSD0fp/Qsxyn0hbqgCQHkYc48thfr5XbuquzIry65do4kBKEisx0ySr+NvcnkfY8n6a/ZSqJ/fubN7/fzOncPFABQk1mMmk6zjF8MyEaVQP59KiSWQsFiPmUyyjp9hmZhiTWzNhNnASHVL7L3WN5XDP1WS+yBSqJ8H0Fesx0xy+KdKch9ECvXzO3d238ZG64EpFOsxkyzq+JsMzMdYsr6h6p5G/XznTVVupgLrxHrMZFJ1/OKGKgCUZ2Q3VM3smWb2T2Z2v5k9aGZ/2KXNM8zsNjN7yMzuMbPZwcJuILS4NPli1A4hRbmF74uY4cbczU3F7F9mhzpI4af96PS7tJdkkp5Tv75A0j2SruxosyDpM/Xr6yTd1m+7Aw3LhBaX5japdEhRbuH7Ima4MXdzUzH7l9mhDlL4ad+IGg7LBI2TS9os6V5Jr+5Y/7eSXlO/bkt6XPV0whstAyX37du7
/6vcvn007SdtZWCwc2m11rctfF/EDDfmbm4qZv8yO9RBCj/tG2ma3BuNuZtZS9JRSS+W9Gl3/0DH59+TdI27P1q//2H9P4DHO9rtlbRXkrZt27bjRLenAHoJnRg5t0mlQx5MKnxfxAw35m5uKmb/MjvUQQo/7RsZ6UNM7n7W3S+TtEXSFWb2ikGCcvdFd59z97mZmZnwDYQWl+ZQjLpaSFFu4fsiZrgxd3NTMfuX2aEOUvhpP1JBde7u/lNJ35B0TcdHJyVtlSQza0t6vqRTowhwjdDi0iyKUVcJKcotfF/EDDfmbm4qZv8yO9RBCj/tR6vfuI2kGUkX1q+fJelbknZ3tLlBa2+ofr7fdgeucw8tLs1tUumQotzC90XMcGPu5qZi9i+zQx2k8NO+L41qzN3MLpV0i6SWqiv9z7v7fjPbX/+SO8zsmZL+TNKrJP1E0nXu/qNe26XOHQDCNR1zb/dr4O4PqEranetvXvX6fyX9cmiQAIA4yp9bZmqfYEAvIadFCqdQzAd3cntIK4XjkYUmYzcxlrHMLVPiEwwYWshpkcIpFPPBndwe0krheEyamFtGk51RH8kKOS1SOIVCY0ihf7ltNydNx9zLTu4lPsGAoYWcFimcQjEf3MntIa0Ujsek8U1M0nQ/wYANhZwWKZxCMR/cye0hrRSORy7KTu5T/QQDNhJyWqRwCsV8cCe3h7RSOB7ZaDIwH2MZ25d1lPYEA0Yi5LRI4RSK+eBObg9ppXA8JkncUAWA8jDmDoxIyBd7pCK3mFOpXU8ljpFocnkfY8n+O1QxFUK+2CMVucWcSu16KnH0I4ZlgOG129LZs+vXt1rSmTPjj6eJ3GJOpXY9lTj6YVgGGIFuSbLX+hTkFvPDD4etLz2OUSG5Az2EfLFHKnKLOZXa9VTiGBWSO9BDyBd7pCK3mFOpXU8ljpFpMjAfY+GGKnIR8sUeqcgt5lRq11OJoxdxQxUAysMNVYxNjrXBsWKOVV+e4z7GhDW5vI+xMCxThlxqg1eLFXOs+vIc9zHiEcMyGIdcaoNXixVzrPryHPcx4mFYBmORY21wrJhj1ZfnuI8xeSR3DCXH2uBYMceqL89xH2PySO4YSo61wbFijlVfnuM+RgKaDMzHWLihWo4caoM7xYo5Vn15jvsYcYgbqgBQHm6oYurEqgUP2S716EhFe9IBAKNw5Eg1tn36dPX+xInzY93z8+PZbqwYgEEwLIMixKoFD9ku9egYB4ZlMFVi1YKHbJd6dKSE5I4ixKoFD9ku9ehICckdRYhVCx6yXerRkRKSO4owPy8tLlbj22bVz8XF4W9khmw3VgzAIPreUDWzrZJulfRCSS5p0d0/0dHmKkl/Lelf61VfdPf9vbbLDVUACDfKG6pnJL3P3V8m6UpJN5jZy7q0+5a7X1YvPRM70pdjvTb16PGx3zLS5DHW1YuqK/Q3dqy7StJXQrbD9APpynH+8JCYc+xfCthvaVCM6QfMbFbSNyW9wt2fXLX+KklfkPSopH+T9Hvu/mCvbTEsk64c67WpR4+P/ZaGpsMyjZO7mT1H0j9I+rC7f7Hjs+dJOufuT5nZtZI+4e4v6bKNvZL2StK2bdt2nOh2pmDiNm2qrss6mUnnzo0/niZCYs6xfylgv6VhpA8xmdkFqq7Mj3Qmdkly9yfd/an69Z2SLjCzi7q0W3T3OXefm5mZafKrMQE51mtTjx4f+y0vfZO7mZmkz0o65u4f26DNi+p2MrMr6u2eGmWgGJ8c67WpR4+P/ZaZfoPykl6nqgTyAUn31cu1kt4l6V11m3dLelDS/ZK+LekX+22XG6ppy3H+8JCYc+xfCthvkyfmcweA8jBx2BSg5nithQWp3a5u8LXb1XtgWjGfe6aYO3ythQXp4MHz78+ePf/+wIHJxARMEsMymaLmeK12u0ronVot6cyZ8ccDxMKwTOGYO3ytbom913qgdCT3TFFzvFarFbYeKB3JPVPUHK+1cr+h6XqgdCT3TDF3
+FoHDkj79p2/Um+1qvfcTMW04oYqAGSEG6qDKLxwvPDuFd+/FLCPM9LkMdYYS3LTDxQ+WXXh3Su+fylgH6dBTD8QqPDC8cK7V3z/UsA+TsPI53MfteSSe+GTVRfeveL7lwL2cRoYcw9VeOF44d0rvn8pYB/nheS+ovDC8cK7V3z/UsA+zgvJfUXhheOFd6/4/qWAfZwXxtwBICOMuQMFiVlfTu16mZjPHUhczLn7+V6AcjEsAyQuZn05tev5YVgGKETMufv5XoBykdyBxMWsL6d2vVwkdyBxMevLqV0vF8kdSFzM+nJq18vFDVUAyAg3VAFgipHcAaBAJHcAKBDJHQAKRHIHgAKR3AGgQCR3ACgQyR0ACtQ3uZvZVjP7hpl938weNLMbu7QxM/ukmT1kZg+Y2eVxwsUwmLcbmB5N5nM/I+l97n6vmT1X0lEzu8vdv7+qzZskvaReXi3pYP0TiWDebmC69L1yd/cfu/u99ev/knRM0iUdzd4i6VavfFvShWZ28cijxcBuuul8Yl9x+nS1HkB5gsbczWxW0qsk3dPx0SWSHln1/lGt/x+AzGyvmS2Z2dLy8nJYpBgK83YD06Vxcjez50j6gqT3uvuTg/wyd1909zl3n5uZmRlkExgQ83YD06VRcjezC1Ql9iPu/sUuTU5K2rrq/ZZ6HRLBvN3AdGlSLWOSPivpmLt/bINmd0h6R101c6WkJ9z9xyOME0Ni3m5gujSplnmtpF+T9F0zu69e9/uStkmSu39G0p2SrpX0kKTTkt45+lAxrPl5kjkwLfomd3f/R0nWp41LumFUQQEAhsMTqgBQIJI7ABSI5A4ABSK5A0CBSO4AUCCSOwAUiOQOAAWyqkR9Ar/YbFnSiYn88v4ukvT4pIOIiP7lq+S+SfSvie3u3ndyrokl95SZ2ZK7z006jljoX75K7ptE/0aJYRkAKBDJHQAKRHLvbnHSAURG//JVct8k+jcyjLkDQIG4cgeAAk11cjezlpl9x8y+0uWz681s2czuq5ffnESMwzCz42b23Tr+pS6fm5l90sweMrMHzOzyScQ5iAZ9u8rMnlh1/G6eRJyDMrMLzex2M/uBmR0zs9d0fJ7tsZMa9S/b42dmL10V931m9qSZvbejTfTj1+TLOkp2o6Rjkp63wee3ufu7xxhPDG9w943qat8k6SX18mpJB+ufuejVN0n6lrvvHls0o/UJSV9197eZ2c9J6viSxOyPXb/+SZkeP3f/Z0mXSdUFpKqvHP1SR7Pox29qr9zNbIukN0s6NOlYJugtkm71yrclXWhmF086qGlnZs+X9HpVX28pd/8/d/9pR7Nsj13D/pVip6QfunvnA5vRj9/UJndJH5f0fknnerR5a/0n0+1mtrVHu1S5pL8zs6NmtrfL55dIemTV+0frdTno1zdJeo2Z3W9mf2NmLx9ncEP6eUnLkv60HjY8ZGbP7miT87Fr0j8p3+O32nWS/qLL+ujHbyqTu5ntlvSYux/t0ezLkmbd/VJJd0m6ZSzBjdbr3P1yVX8C3mBmr590QCPUr2/3qnpM+5WS/kTSX407wCG0JV0u6aC7v0rSf0v64GRDGqkm/cv5+EmS6uGmPZL+chK/fyqTu6ov/d5jZsclfU7S1WZ2eHUDdz/l7k/Xbw9J2jHeEIfn7ifrn4+pGvO7oqPJSUmr/yLZUq9LXr++ufuT7v5U/fpOSReY2UVjD3Qwj0p61N3vqd/frioZrpbtsVOD/mV+/Fa8SdK97v4fXT6LfvymMrm7+4fcfYu7z6r6s+nr7v721W06xr/2qLrxmg0ze7aZPXfltaRfkvS9jmZ3SHpHfef+SklPuPuPxxxqsCZ9M7MXmZnVr69Qda6fGnesg3D3f5f0iJm9tF61U9L3O5pleeykZv3L+fit8qvqPiQjjeH4TXu1zBpmtl/SkrvfIek9ZrZH0hlJP5F0/SRjG8ALJX2p/vfRlvTn7v5VM3uXJLn7ZyTdKelaSQ9JOi3pnROK
NVSTvr1N0j4zOyPpfyRd53k9sffbko7Uf9r/SNI7Czl2K/r1L+vjV190vFHSb61aN9bjxxOqAFCgqRyWAYDSkdwBoEAkdwAoEMkdAApEcgeAApHcAaBAJHcAKBDJHQAK9P9IUj1h6gimcQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x109124550>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 分类0的散点图\n",
    "plt.scatter(X[y==0,0], X[y==0,1], color='red')\n",
    "\n",
    "# 分类1的散点图\n",
    "plt.scatter(X[y==1,0], X[y==1,1], color='blue')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用我们自己编写的逻辑回归"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from playML.model_selection import train_test_split\n",
    "\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, seed=666)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LogisticRegression()"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from playML.logistic_regression import LogisticRegression\n",
    "\n",
    "log_reg = LogisticRegression()\n",
    "log_reg.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.score(X_test, y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "评分结果不错，不过当然是因为我们的数据很简单"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.92972035,  0.98664939,  0.14852024,  0.17601199,  0.0369836 ,\n        0.0186637 ,  0.04936918,  0.99669244,  0.97993941,  0.74524655,\n        0.04473194,  0.00339285,  0.26131273,  0.0369836 ,  0.84192923,\n        0.79892262,  0.82890209,  0.32358166,  0.06535323,  0.20735334])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.predict_proba(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "log_reg.predict(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
